{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c24e4a0f-30dd-4cbf-be1f-7bb2a0ab269f",
   "metadata": {},
   "source": [
    "# EmotionPrompt in RAG\n",
    "\n",
    "Inspired by the paper \"[Large Language Models Understand and Can Be Enhanced by Emotional Stimuli](https://arxiv.org/pdf/2307.11760.pdf)\" by Li et al., this guide shows how to evaluate the effect of emotional stimuli on a RAG pipeline:\n",
    "\n",
    "1. Set up the RAG pipeline with a basic vector index and the core QA template.\n",
    "2. Create some candidate stimuli (inspired by Fig. 2 of the paper).\n",
    "3. For each candidate stimulus, add it to the QA prompt and evaluate.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35e619cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-llms-openai\n",
    "%pip install llama-index-embeddings-openai\n",
    "%pip install llama-index-readers-file pymupdf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1edda09-993e-46b4-a353-1bacf36e0115",
   "metadata": {},
   "source": [
    "## Setup Data\n",
    "\n",
    "We use the Llama 2 paper as the input data source for our RAG pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20102a1b-7747-466f-8f41-e6cfff55c194",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p llama_2_data && wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"llama_2_data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38cc093a-b83f-4c6b-96b6-191c696b9a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.readers.file import PyMuPDFReader\n",
    "from llama_index.core import Document\n",
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "\n",
    "docs0 = PyMuPDFReader().load_data(\"./llama_2_data/llama2.pdf\")\n",
    "\n",
    "# combine all documents into one\n",
    "doc_text = \"\\n\\n\".join([d.get_content() for d in docs0])\n",
    "docs = [Document(text=doc_text)]\n",
    "\n",
    "# split the document into chunks of 1024 tokens\n",
    "node_parser = SentenceSplitter(chunk_size=1024)\n",
    "base_nodes = node_parser.get_nodes_from_documents(docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f43bde13-f1a2-4c74-ba7c-562f11d11004",
   "metadata": {},
   "source": [
    "## Setup Vector Index over this Data\n",
    "\n",
    "We load this data into an in-memory vector store (embedded with OpenAI embeddings).\n",
    "\n",
    "We'll then vary the QA prompt of this RAG pipeline by adding emotional stimuli and compare the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab68750d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d29f0ff-3a5c-4102-867c-2289b9aee617",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import Settings\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
    "\n",
    "\n",
    "Settings.llm = OpenAI(model=\"gpt-4o-mini\")\n",
    "Settings.embed_model = OpenAIEmbedding(model=\"text-embedding-3-small\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8be2599b-347e-48a1-9573-009d2478fce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "index = VectorStoreIndex(base_nodes)\n",
    "\n",
    "query_engine = index.as_query_engine(similarity_top_k=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87e6adc9-7485-4aa2-b346-a47dacfbfb44",
   "metadata": {},
   "source": [
    "## Evaluation Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4404f265-304c-420f-9945-512e07e2dc1b",
   "metadata": {},
   "source": [
    "#### Golden Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca8f3198-f791-492a-99ca-6cd32d5d6a01",
   "metadata": {},
   "source": [
    "Here we load a \"golden\" dataset of query-response pairs to evaluate against.\n",
    "\n",
    "**NOTE**: We pull this in from Dropbox. For details on how to generate a dataset please see our `DatasetGenerator` module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bf4daac-1009-408c-b608-ab6448076105",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget \"https://www.dropbox.com/scl/fi/fh9vsmmm8vu0j50l3ss38/llama2_eval_qr_dataset.json?rlkey=kkoaez7aqeb4z25gzc06ak6kb&dl=1\" -O llama2_eval_qr_dataset.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f523221-f756-4d05-86fc-03806f1561b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.evaluation import QueryResponseDataset\n",
    "\n",
    "# load the golden query-response pairs used for evaluation below\n",
    "eval_dataset = QueryResponseDataset.from_json(\"./llama2_eval_qr_dataset.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57e4c307-6e2e-4fce-a1f8-3bd87df3b6b0",
   "metadata": {},
   "source": [
    "#### Get Evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c60d47f9-9bbc-4cad-9487-c61953680642",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.evaluation import CorrectnessEvaluator, BatchEvalRunner\n",
    "\n",
    "\n",
    "evaluator_c = CorrectnessEvaluator()\n",
    "\n",
    "evaluator_dict = {\"correctness\": evaluator_c}\n",
    "batch_runner = BatchEvalRunner(evaluator_dict, workers=2, show_progress=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc973088-598e-4d17-a24e-659393f6d412",
   "metadata": {},
   "source": [
    "#### Define Correctness Eval Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78fbe75e-5fa3-4db4-b3dc-6e3f3c1c34dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from llama_index.core.evaluation.eval_utils import aget_responses\n",
    "\n",
    "\n",
    "async def get_correctness(query_engine, eval_qa_pairs, batch_runner):\n",
    "    \"\"\"Return the mean correctness score of the query engine on the QA pairs.\"\"\"\n",
    "    # TODO: evaluate a sample of generated results\n",
    "    eval_qs = [q for q, _ in eval_qa_pairs]\n",
    "    eval_answers = [a for _, a in eval_qa_pairs]\n",
    "    pred_responses = await aget_responses(\n",
    "        eval_qs, query_engine, show_progress=True\n",
    "    )\n",
    "\n",
    "    eval_results = await batch_runner.aevaluate_responses(\n",
    "        eval_qs, responses=pred_responses, reference=eval_answers\n",
    "    )\n",
    "    avg_correctness = np.array(\n",
    "        [r.score for r in eval_results[\"correctness\"]]\n",
    "    ).mean()\n",
    "    return avg_correctness"
   ]
  },
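  {
   "cell_type": "markdown",
   "id": "e3f1a2b4-mean-stderr-note",
   "metadata": {},
   "source": [
    "`get_correctness` reduces the per-question scores to a single mean. When the means of two prompt variants are close, it can help to look at the spread as well. The cell below is a minimal sketch, assuming you have collected the per-question scores in a list; `mean_and_stderr` and `example_scores` are hypothetical names, and the scores are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3f1a2b4-mean-stderr-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import statistics\n",
    "\n",
    "\n",
    "def mean_and_stderr(scores):\n",
    "    \"\"\"Return the mean and the standard error of the mean of the scores.\"\"\"\n",
    "    mean = statistics.mean(scores)\n",
    "    # statistics.stdev is the sample standard deviation (n - 1 denominator)\n",
    "    stderr = statistics.stdev(scores) / math.sqrt(len(scores))\n",
    "    return mean, stderr\n",
    "\n",
    "\n",
    "# Hypothetical per-question correctness scores, not taken from a real run.\n",
    "example_scores = [5.0, 4.0, 4.5, 3.0, 5.0, 4.5]\n",
    "example_mean, example_stderr = mean_and_stderr(example_scores)\n",
    "print(f\"mean={example_mean:.2f}, stderr={example_stderr:.2f}\")"
   ]
  },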
  {
   "cell_type": "markdown",
   "id": "20e80faf-9b14-43f6-b8db-babffbc6c285",
   "metadata": {},
   "source": [
    "## Try Out Emotion Prompts\n",
    "\n",
    "We pull some emotional stimuli from the paper to try out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7de3f22-3dfd-409f-96a2-0ee0619c0b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "emotion_stimuli_dict = {\n",
    "    \"ep01\": \"Write your answer and give me a confidence score between 0-1 for your answer. \",\n",
    "    \"ep02\": \"This is very important to my career. \",\n",
    "    \"ep03\": \"You'd better be sure.\",\n",
    "    # add more from the paper here!!\n",
    "}\n",
    "\n",
    "# NOTE: ep06 is the combination of ep01, ep02, ep03\n",
    "emotion_stimuli_dict[\"ep06\"] = (\n",
    "    emotion_stimuli_dict[\"ep01\"]\n",
    "    + emotion_stimuli_dict[\"ep02\"]\n",
    "    + emotion_stimuli_dict[\"ep03\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a43822dc-acaa-4527-a354-09420e25b1f8",
   "metadata": {},
   "source": [
    "#### Initialize base QA Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20cee335-2b85-40fb-9a29-8d8ff7078271",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.prompts import RichPromptTemplate\n",
    "\n",
    "qa_tmpl_str = \"\"\"\\\n",
    "Context information is below. \n",
    "---------------------\n",
    "{{ context_str }}\n",
    "---------------------\n",
    "Given the context information and not prior knowledge, \\\n",
    "answer the query.\n",
    "{{ emotion_str }}\n",
    "Query: {{ query_str }}\n",
    "Answer: \\\n",
    "\"\"\"\n",
    "qa_tmpl = RichPromptTemplate(qa_tmpl_str)"
   ]
  },
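  {
   "cell_type": "markdown",
   "id": "b9d4c7a1-prompt-preview-note",
   "metadata": {},
   "source": [
    "To check where a stimulus ends up in the final prompt, here is a minimal, library-free sketch that fills the `{{ ... }}` placeholders by plain string replacement. `preview_tmpl_str` and `render_preview` are hypothetical helpers introduced for illustration; this is not how `RichPromptTemplate` renders prompts internally, and the context and query values are placeholders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d4c7a1-prompt-preview-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain-string copy of the QA template above, used only for this preview.\n",
    "preview_tmpl_str = (\n",
    "    \"Context information is below.\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"{{ context_str }}\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"Given the context information and not prior knowledge, answer the query.\\n\"\n",
    "    \"{{ emotion_str }}\\n\"\n",
    "    \"Query: {{ query_str }}\\n\"\n",
    "    \"Answer: \"\n",
    ")\n",
    "\n",
    "\n",
    "def render_preview(template_str, **values):\n",
    "    \"\"\"Fill `{{ name }}` placeholders by plain string replacement.\"\"\"\n",
    "    rendered = template_str\n",
    "    for name, value in values.items():\n",
    "        rendered = rendered.replace(\"{{ \" + name + \" }}\", value)\n",
    "    return rendered\n",
    "\n",
    "\n",
    "print(\n",
    "    render_preview(\n",
    "        preview_tmpl_str,\n",
    "        context_str=\"<retrieved context>\",\n",
    "        emotion_str=\"This is very important to my career. \",\n",
    "        query_str=\"<user question>\",\n",
    "    )\n",
    ")"
   ]
  },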
  {
   "cell_type": "markdown",
   "id": "da84429e-2da6-4c72-b903-d0acda4e064a",
   "metadata": {},
   "source": [
    "#### Prepend Emotional Stimuli"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25b5b584",
   "metadata": {},
   "outputs": [],
   "source": [
    "QA_PROMPT_KEY = \"response_synthesizer:text_qa_template\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ada9f49d-42a5-4724-9313-9701cff6691b",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def run_and_evaluate(\n",
    "    query_engine, eval_qa_pairs, batch_runner, emotion_stimuli_str, qa_tmpl\n",
    "):\n",
    "    \"\"\"Evaluate correctness with the given stimulus, then restore the original QA prompt.\"\"\"\n",
    "    new_qa_tmpl = qa_tmpl.partial_format(emotion_str=emotion_stimuli_str)\n",
    "\n",
    "    old_qa_tmpl = query_engine.get_prompts()[QA_PROMPT_KEY]\n",
    "    query_engine.update_prompts({QA_PROMPT_KEY: new_qa_tmpl})\n",
    "    avg_correctness = await get_correctness(\n",
    "        query_engine, eval_qa_pairs, batch_runner\n",
    "    )\n",
    "    query_engine.update_prompts({QA_PROMPT_KEY: old_qa_tmpl})\n",
    "    return avg_correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26cbe3c6-a134-4ce6-b093-29e8aea8a7e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 60/60 [00:17<00:00,  3.43it/s]\n",
      "100%|██████████| 60/60 [00:44<00:00,  1.34it/s]\n"
     ]
    }
   ],
   "source": [
    "# try out ep01\n",
    "correctness_ep01 = await run_and_evaluate(\n",
    "    query_engine,\n",
    "    eval_dataset.qr_pairs,\n",
    "    batch_runner,\n",
    "    emotion_stimuli_dict[\"ep01\"],\n",
    "    qa_tmpl,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590cf18b-2805-4252-9975-12e0179d8108",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.283333333333333\n"
     ]
    }
   ],
   "source": [
    "print(correctness_ep01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed68191c-fd82-4a19-b15e-d21e6f8490fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 60/60 [00:17<00:00,  3.49it/s]\n",
      "100%|██████████| 60/60 [00:46<00:00,  1.28it/s]\n"
     ]
    }
   ],
   "source": [
    "# try out ep02\n",
    "correctness_ep02 = await run_and_evaluate(\n",
    "    query_engine,\n",
    "    eval_dataset.qr_pairs,\n",
    "    batch_runner,\n",
    "    emotion_stimuli_dict[\"ep02\"],\n",
    "    qa_tmpl,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8dd8042-ee7e-4486-b975-1d063aebf306",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.466666666666667\n"
     ]
    }
   ],
   "source": [
    "print(correctness_ep02)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4857e10e-da32-48bf-a0f2-7a545660742a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 60/60 [00:12<00:00,  4.74it/s]\n",
      "100%|██████████| 60/60 [00:45<00:00,  1.32it/s]\n"
     ]
    }
   ],
   "source": [
    "# baseline: no emotional stimulus\n",
    "correctness_base = await run_and_evaluate(\n",
    "    query_engine, eval_dataset.qr_pairs, batch_runner, \"\", qa_tmpl\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7492cbf6-e4c9-4b43-8164-96596a24730b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.533333333333333\n"
     ]
    }
   ],
   "source": [
    "print(correctness_base)"
   ]
  },
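  {
   "cell_type": "markdown",
   "id": "a7c1e0d2-score-summary-note",
   "metadata": {},
   "source": [
    "To make the comparison easier to read, the cell below is a minimal sketch that lists each score next to its difference from the no-stimulus baseline. The values are copied from the printed outputs of the run above so that the cell is self-contained; `scores` and `delta_vs_base` are illustrative names, and rerunning the notebook will likely give slightly different numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c1e0d2-score-summary-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scores copied from the printed outputs above (assumption: when rerunning,\n",
    "# substitute correctness_base, correctness_ep01 and correctness_ep02).\n",
    "scores = {\n",
    "    \"base\": 4.533333333333333,\n",
    "    \"ep01\": 4.283333333333333,\n",
    "    \"ep02\": 4.466666666666667,\n",
    "}\n",
    "\n",
    "# Difference of each variant from the no-stimulus baseline.\n",
    "delta_vs_base = {\n",
    "    name: score - scores[\"base\"] for name, score in scores.items()\n",
    "}\n",
    "\n",
    "# Print the variants from highest to lowest score.\n",
    "for name, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):\n",
    "    print(f\"{name}: {score:.3f} ({delta_vs_base[name]:+.3f} vs. base)\")"
   ]
  },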
  {
   "cell_type": "markdown",
   "id": "b1fd4497",
   "metadata": {},
   "source": [
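    "In this run, the baseline without any stimulus scored highest (4.53), followed by ep02 (4.47) and ep01 (4.28), so the emotional stimuli did not improve correctness here. The gaps are small and come from a single run, so treat them as indicative rather than conclusive; results may differ with other models, datasets, or stimuli."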
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama-index-caVs7DDe-py3.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
